﻿using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Reflection;
using System.Security;
using Belikov.GenuineChannels.Security;
using Microsoft.VisualStudio.TestTools.UnitTesting;
using Zyan.SafeDeserializationHelpers;

namespace GenuineChannels.UnitTests.UnitTests
{
	/// <summary>
	/// Sends a known remote-code-execution deserialization gadget (ysoserial.net's
	/// TypeConfuseDelegate) to a remoting service over the GTCP channel and expects every
	/// such call to fail with <see cref="SecurityException"/>.
	/// </summary>
	[TestClass]
	public class VulnerabilityTests : RemoteServerTestBase
	{
		#region Test Setup

		/// <summary>
		/// Passes the GTCP channel type, a factory for <see cref="Service"/> and IPv4 to the base fixture.
		/// </summary>
		public VulnerabilityTests()
			: base(GcChannelType.GTCP, () => new Service(), IpVersion.IPv4)
		{
		}

		protected override string ServiceAddress => "localhost"; // Should work on IPv4, IPv6 and dual-stack systems
		protected override int Port => 8732;
		protected override string GcChannelName => "vtcp";

		/// <summary>
		/// Command line embedded in the gadget; it would be launched if the payload were accepted.
		/// </summary>
		private const string PayloadCommand = "calc";

		#endregion

		#region Service def.

		/// <summary>
		/// Trivial remote service. It ignores its argument: if <see cref="TestMethod"/> ever runs,
		/// the malicious argument has already been deserialized on the receiving side.
		/// </summary>
		internal class Service : MarshalByRefObject, IService
		{
			public string TestMethod(object payload)
			{
				return "Pwned";
			}
		}

		/// <summary>
		/// Remote contract used to carry an arbitrary object argument across the channel.
		/// </summary>
		public interface IService
		{
			string TestMethod(object payload);
		}

		#endregion

#if FRM40
		/// <summary>
		/// Wraps <paramref name="comparison"/> in a serializable comparer using the BCL factory.
		/// </summary>
		private static Comparer<T> CreateComparer<T>(Comparison<T> comparison)
		{
			return Comparer<T>.Create(comparison);
		}
#else
		// Stand-ins for targets built without FRM40 (presumably frameworks lacking
		// Comparer<T>.Create and/or SortedSet<T> — confirm against the project configuration).

		/// <summary>
		/// Wraps <paramref name="comparison"/> in a serializable comparer.
		/// </summary>
		private static Comparer<T> CreateComparer<T>(Comparison<T> comparison)
		{
			return new ComparisonComparer<T>(comparison);
		}

		/// <summary>
		/// Serializable <see cref="Comparer{T}"/> that delegates to a <see cref="Comparison{T}"/>.
		/// </summary>
		[Serializable]
		public class ComparisonComparer<T> : Comparer<T>
		{
			private readonly Comparison<T> _comparison;
			public ComparisonComparer(Comparison<T> comparison) => _comparison = comparison;
			public override int Compare(T x, T y) => _comparison(x, y);
		}

		/// <summary>
		/// Minimal serializable container with the same shape as the real gadget (a comparer plus items).
		/// It only stores items in insertion order and never calls the comparer itself.
		/// </summary>
		[Serializable]
		public class SortedSet<T>
		{
			private IComparer<T> Comparer { get; }
			private List<T> Items { get; } = new List<T>();
			public SortedSet(IComparer<T> comparer) => Comparer = comparer;
			public void Add(T item) => Items.Add(item);
		}
#endif

		// Taken from https://github.com/pwntester/ysoserial.net/blob/master/ysoserial/Generators/TypeConfuseDelegateGenerator.cs#L26
		/// <summary>
		/// Builds the ysoserial.net "TypeConfuseDelegate" gadget. The result is a sorted set whose
		/// comparer wraps a two-entry multicast delegate. After the items have been added, the
		/// second delegate entry is replaced with <see cref="Process.Start(string, string)"/>.
		/// </summary>
		/// <param name="cmd">Command appended to <c>"/c "</c> as the second set item.</param>
		/// <returns>The poisoned set, intended to be sent as a remoting call argument.</returns>
		/// <exception cref="InvalidOperationException">
		/// The runtime's <see cref="MulticastDelegate"/> has no private instance field named <c>_invocationList</c>.
		/// </exception>
		public static object GetTypeConfuseDelegate(string cmd)
		{
			var benignComparison = new Comparison<string>(string.Compare);
			var multicast = (Comparison<string>)Delegate.Combine(benignComparison, benignComparison);
			var set = new SortedSet<string>(CreateComparer(multicast));

			// Add the items while the delegate is still harmless (string.Compare twice),
			// so nothing is executed on the sending side.
			set.Add("cmd");
			set.Add("/c " + cmd);

			var invocationListField = typeof(MulticastDelegate).GetField("_invocationList", BindingFlags.NonPublic | BindingFlags.Instance);
			if (invocationListField == null)
			{
				throw new InvalidOperationException(
					"MulticastDelegate._invocationList was not found; the TypeConfuseDelegate gadget cannot be built on this runtime.");
			}

			// Replace the second invocation-list entry with Process.Start(string, string).
			// The comparer captured by the set keeps a reference to 'multicast', so the
			// serialized set now carries the poisoned delegate.
			object[] invocationList = multicast.GetInvocationList();
			invocationList[1] = new Func<string, string, Process>(Process.Start);
			invocationListField.SetValue(multicast, invocationList);

			return set;
		}

		#region Helpers

		/// <summary>
		/// Creates a remoting proxy for <see cref="IService"/> at <paramref name="uri"/>.
		/// </summary>
		private static IService CreateProxy(string uri)
		{
			// note: a localhost url should work, although it may be resolved to an IPv6 address
			return (IService)Activator.GetObject(typeof(IService), uri);
		}

		/// <summary>
		/// Builds security session parameters for the default context with compression enabled.
		/// </summary>
		private static SecuritySessionParameters CreateCompressionParameters()
		{
			return new SecuritySessionParameters(
				SecuritySessionServices.DefaultContext.Name,
				SecuritySessionAttributes.EnableCompression,
				TimeSpan.FromSeconds(5));
		}

		/// <summary>
		/// Sends the malicious payload through <paramref name="proxy"/>. The call is expected to
		/// throw <see cref="SecurityException"/>. If it returns, the test is failed explicitly.
		/// </summary>
		private static void InvokeWithMaliciousPayload(IService proxy)
		{
			var maliciousPayload = GetTypeConfuseDelegate(PayloadCommand);
			var result = proxy.TestMethod(maliciousPayload);

			// Reaching this line means the payload was accepted and the service method ran.
			Assert.Fail($"The malicious payload was not rejected; the service returned '{result}'.");
		}

		/// <summary>
		/// Same as <see cref="InvokeWithMaliciousPayload"/>, but with compression enabled for the call.
		/// </summary>
		private static void InvokeWithMaliciousPayloadCompressed(string uri)
		{
			var parameters = CreateCompressionParameters();
			var proxy = CreateProxy(uri);

			using (new SecurityContextKeeper(parameters))
			{
				InvokeWithMaliciousPayload(proxy);
			}
		}

		#endregion

		[TestMethod, ExpectedException(typeof(SecurityException))]
		public void RemoteInvocation()
		{
			InvokeWithMaliciousPayload(CreateProxy(ServiceUri));
		}

		[TestMethod, ExpectedException(typeof(SecurityException))]
		public void RemoteInvocationWithPostfix()
		{
			InvokeWithMaliciousPayload(CreateProxy(ServiceUriRem));
		}

		[TestMethod, ExpectedException(typeof(SecurityException))]
		public void RemoteInvocationWithCompression()
		{
			InvokeWithMaliciousPayloadCompressed(ServiceUri);
		}

		[TestMethod, ExpectedException(typeof(SecurityException))]
		public void RemoteInvocationWithCompressionAndPostfix()
		{
			InvokeWithMaliciousPayloadCompressed(ServiceUriRem);
		}
	}
}
